{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3eba4e9b",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Gradio: Make Your Machine Learning Models Look Attractive"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d314ec07",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
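    "Gradio is an open-source Python library for quickly building web demo pages for machine learning models.\n",
    "\n",
    "With just a few lines of code, you can turn a machine learning model from abstract, hard-to-read code into an attractive, interactive interface.\n",
    "\n",
    "Users with no programming skills can then try out the model easily.\n",
    "\n",
    "It is well suited to collecting user feedback quickly during model iteration and testing, and to live demos in presentations.\n",
    "\n",
    "Compared with Streamlit, another library for building machine learning web demos, Gradio has the following advantages:\n",
    "\n",
    "* Easy to share: passing share=True when launching the app creates a public share link, which can be sent to users directly, for example in WeChat.\n",
    "\n",
    "* Easy to debug: Gradio can render the page directly inside Jupyter, which makes debugging more convenient.\n",
    "\n",
    "\n",
    "\n",
    "Most Gradio apps are built from the following commonly used basic modules.\n",
    "\n",
    "* App interface: gr.Interface (simple scenarios), gr.Blocks (customized scenarios)\n",
    "\n",
    "* Inputs and outputs: gr.Image (image), gr.Textbox (text box), gr.DataFrame (data frame), gr.Dropdown (dropdown), gr.Number (number), gr.Markdown, gr.Files\n",
    "\n",
    "* Control components: gr.Button (button)\n",
    "\n",
    "* Layout components: gr.Tab (tab), gr.Row (row layout), gr.Column (column layout)"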
   ]
  },
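  {
   "cell_type": "markdown",
   "id": "b7e41c90",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
    "To make the module list above concrete, here is a minimal sketch that combines a layout (gr.Row, gr.Column), input and output components (gr.Textbox, gr.Number) and a control (gr.Button) inside gr.Blocks. The names `word_count` and `build_demo` are illustrative only, and the word counter is a toy stand-in for a real model.\n",
    "\n",
    "```python\n",
    "def word_count(text):\n",
    "    # Toy stand-in for a model: count whitespace-separated words\n",
    "    return len(text.split())\n",
    "\n",
    "\n",
    "def build_demo():\n",
    "    # gradio is imported here so word_count can be tested on its own\n",
    "    import gradio as gr\n",
    "\n",
    "    with gr.Blocks() as demo:\n",
    "        gr.Markdown('# Word counter')\n",
    "        with gr.Row():\n",
    "            with gr.Column():\n",
    "                text_in = gr.Textbox(label='Input text')\n",
    "                run = gr.Button('Count')\n",
    "            with gr.Column():\n",
    "                count_out = gr.Number(label='Words')\n",
    "        # Wire the button to the callback: Textbox value in, Number value out\n",
    "        run.click(word_count, inputs=text_in, outputs=count_out)\n",
    "    return demo\n",
    "```\n",
    "\n",
    "In a notebook, `build_demo().launch()` would display the page inline."
   ]
  },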
  {
   "cell_type": "markdown",
   "id": "2f9a3ff2",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "We will introduce Gradio through 5 examples, from easy to hard.\n",
    "\n",
    "* Hello World (gr.Interface)\n",
    "\n",
    "* Text classification (gr.Interface)\n",
    "\n",
    "* Image classification (gr.Interface)\n",
    "\n",
    "* Object detection (customized with gr.Blocks)\n",
    "\n",
    "* Image picker (gr.Blocks)\n",
    "\n",
    "References:\n",
    "\n",
    "* Official tutorial: https://gradio.app/\n",
    "\n",
    "* Bilibili video demo: https://www.bilibili.com/video/BV1BN411w7Ub/"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2dda4254-3a2a-4cfc-9df4-14c544360f02",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## 1. Hello World (Difficulty: ⭐️)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f839ea9a",
   "metadata": {
    "code_folding": [],
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "\n",
    "def greet(name):\n",
    "    return \"Hello \" + name + \"!!\"\n",
    "\n",
    "demo = gr.Interface(fn=greet, inputs=\"text\", outputs=\"text\")\n",
    "gr.close_all()\n",
    "demo.launch(share=True)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46b19350",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## 2. Text Classification (Difficulty: ⭐️⭐️)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e5e57ff",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "#!pip install gradio ultralytics transformers torchkeras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdb83c57-36e3-42b3-909d-cd30f51d7ff2",
   "metadata": {
    "code_folding": [],
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "import gradio as gr \n",
    "from transformers import pipeline\n",
    "\n",
    "pipe = pipeline(\"text-classification\")\n",
    "\n",
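    "def clf(text):\n",
    "    # This code assumes the default model's two labels, POSITIVE and NEGATIVE\n",
    "    result = pipe(text)\n",
    "    label = result[0]['label']\n",
    "    score = result[0]['score']\n",
    "    # Return a {label: confidence} dict; the other label gets the remaining probability\n",
    "    other = 'POSITIVE' if label == 'NEGATIVE' else 'NEGATIVE'\n",
    "    return {label: score, other: 1 - score}\n",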
    "\n",
    "demo = gr.Interface(fn=clf, inputs=\"text\", outputs=\"label\")\n",
    "gr.close_all()\n",
    "demo.launch(share=True)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e791d31f",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## 3. Image Classification (Difficulty: ⭐️⭐️⭐)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be5483f3",
   "metadata": {
    "code_folding": [
     11
    ],
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "import gradio as gr \n",
    "import pandas as pd \n",
    "from ultralytics import YOLO\n",
    "from skimage import data\n",
    "from PIL import Image\n",
    "\n",
    "model = YOLO('yolov8n-cls.pt')\n",
    "#prepare example image \n",
    "Image.fromarray(data.coffee()).save('coffee.jpeg') \n",
    "Image.fromarray(data.astronaut()).save('people.jpeg')\n",
    "Image.fromarray(data.cat()).save('cat.jpeg')\n",
    "def predict(img):\n",
    "    result = model.predict(source=img)\n",
    "    df = pd.Series(result[0].names).to_frame()\n",
    "    df.columns = ['names']\n",
    "    df['probs'] = result[0].probs\n",
    "    df = df.sort_values('probs',ascending=False)\n",
    "    res = dict(zip(df['names'],df['probs']))\n",
    "    return res\n",
    "gr.close_all() \n",
    "demo = gr.Interface(fn = predict,inputs = gr.Image(type='pil'), outputs = gr.Label(num_top_classes=5), \n",
    "                    examples = ['cat.jpeg','people.jpeg','coffee.jpeg'])\n",
    "demo.launch()\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7fe6b3b4",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## 4. Object Detection (Difficulty: ⭐️⭐️⭐⭐️)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c82da276",
   "metadata": {
    "code_folding": [],
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "import gradio as gr \n",
    "import pandas as pd \n",
    "from skimage import data\n",
    "from PIL import Image\n",
    "from torchkeras import plots \n",
    "from torchkeras.data import get_url_img\n",
    "from pathlib import Path\n",
    "from ultralytics import YOLO\n",
    "import ultralytics\n",
    "from ultralytics.yolo.data import utils \n",
    "\n",
    "model = YOLO('yolov8n.pt')\n",
    "\n",
    "#prepare example images\n",
    "Image.fromarray(data.coffee()).save('coffee.jpeg') \n",
    "Image.fromarray(data.astronaut()).save('people.jpeg')\n",
    "Image.fromarray(data.cat()).save('cat.jpeg')\n",
    "\n",
    "#load class_names\n",
    "yaml_path = str(Path(ultralytics.__file__).parent/'datasets/coco128.yaml') \n",
    "class_names = utils.yaml_load(yaml_path)['names']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c2c16f9",
   "metadata": {
    "code_folding": [],
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "def detect(img):\n",
    "    if isinstance(img,str):\n",
    "        img = get_url_img(img) if img.startswith('http') else Image.open(img).convert('RGB')\n",
    "    result = model.predict(source=img)\n",
    "    if len(result[0].boxes.boxes)>0:\n",
    "        vis = plots.plot_detection(img,boxes=result[0].boxes.boxes,\n",
    "                     class_names=class_names, min_score=0.2)\n",
    "    else:\n",
    "        vis = img\n",
    "    return vis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94ca2887",
   "metadata": {
    "code_folding": [
     28,
     40
    ],
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "with gr.Blocks() as demo:\n",
    "    gr.Markdown(\"# YOLOv8 Object Detection Demo\")\n",
    "\n",
    "    with gr.Tab(\"Webcam capture\"):\n",
    "        in_img = gr.Image(source='webcam',type='pil')\n",
    "        button = gr.Button(\"Run detection\",variant=\"primary\")\n",
    "\n",
    "        gr.Markdown(\"## Prediction output\")\n",
    "        out_img = gr.Image(type='pil')\n",
    "\n",
    "        button.click(detect,\n",
    "                     inputs=in_img, \n",
    "                     outputs=out_img)\n",
    "        \n",
    "    \n",
    "    with gr.Tab(\"Choose a test image\"):\n",
    "        files = ['people.jpeg','coffee.jpeg','cat.jpeg']\n",
    "        drop_down = gr.Dropdown(choices=files,value=files[0])\n",
    "        button = gr.Button(\"Run detection\",variant=\"primary\")\n",
    "        \n",
    "        \n",
    "        gr.Markdown(\"## Prediction output\")\n",
    "        out_img = gr.Image(type='pil')\n",
    "        \n",
    "        button.click(detect,\n",
    "                     inputs=drop_down, \n",
    "                     outputs=out_img)\n",
    "        \n",
    "    with gr.Tab(\"Enter an image URL\"):\n",
    "        default_url = 'https://t7.baidu.com/it/u=3601447414,1764260638&fm=193&f=GIF'\n",
    "        url = gr.Textbox(value=default_url)\n",
    "        button = gr.Button(\"Run detection\",variant=\"primary\")\n",
    "        \n",
    "        gr.Markdown(\"## Prediction output\")\n",
    "        out_img = gr.Image(type='pil')\n",
    "        \n",
    "        button.click(detect,\n",
    "                     inputs=url, \n",
    "                     outputs=out_img)\n",
    "        \n",
    "    with gr.Tab(\"Upload a local image\"):\n",
    "        input_img = gr.Image(type='pil')\n",
    "        button = gr.Button(\"Run detection\",variant=\"primary\")\n",
    "        \n",
    "        gr.Markdown(\"## Prediction output\")\n",
    "        out_img = gr.Image(type='pil')\n",
    "        \n",
    "        button.click(detect,\n",
    "                     inputs=input_img, \n",
    "                     outputs=out_img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40b105ae",
   "metadata": {
    "code_folding": [],
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "gr.close_all() \n",
    "demo.queue(concurrency_count=5)\n",
    "demo.launch()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d52b0d72",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## 5. Image Picker (Difficulty: ⭐️⭐️⭐⭐️⭐️)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "940da76f-6cb9-4a44-94a8-6bb6f91b4668",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
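    "Gradio was originally designed for quickly building interactive pages for machine learning models.\n",
    "\n",
    "In practice, though, combining its components makes it easy to build all kinds of handy small tools.\n",
    "\n",
    "Examples include a data analysis dashboard, a data labeling tool, or the interface for a small game.\n",
    "\n",
    "In this example we use Gradio to build an image picker that selects the ones we like from a batch of cat memes crawled from Baidu."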
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d02264a7-b87d-45dd-bcfd-70b7e2fafcde",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "#!pip install -U torchkeras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0a9a0ac-3456-417c-94b2-cc2889e305d5",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "import torchkeras \n",
    "from importlib import reload \n",
    "reload(torchkeras)\n",
    "from torchkeras.data import download_baidu_pictures \n",
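    "# '猫咪表情包' is the search keyword ('cat memes'); the next cell reads the images from a folder of the same name\n",
    "download_baidu_pictures('猫咪表情包',100)"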
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6fa0ed3",
   "metadata": {
    "code_folding": [
     11,
     12,
     24,
     25,
     36,
     57
    ],
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "from PIL import Image\n",
    "import os\n",
    "from pathlib import Path \n",
    "from torchkeras.data import get_url_img\n",
    "base_dir = '猫咪表情包'\n",
    "selected_dir = 'selected'\n",
    "files = [str(x) for x in Path(base_dir).rglob('*.jp*g') \n",
    "         if 'checkpoint' not in str(x)]\n",
    "def show_img(path):\n",
    "    return Image.open(path)\n",
    "def fn_before(done,todo):\n",
    "    if done>=1:\n",
    "        done = done-1\n",
    "        todo = todo+1\n",
    "        \n",
    "    else:\n",
    "        done = done\n",
    "        todo = todo\n",
    "    \n",
    "    path = files[max(int(done)-1, 0)]  # clamp so done=0 shows the first image, not files[-1]\n",
    "    img = show_img(path)\n",
    "    \n",
    "    return done,todo,path,img\n",
    "def fn_next(done,todo):\n",
    "    if todo>=1:\n",
    "        done = done+1\n",
    "        todo = todo-1\n",
    "    else:\n",
    "        done = done\n",
    "        todo = todo\n",
    "        \n",
    "    path = files[max(int(done)-1, 0)]  # done is the 1-based position of the current image\n",
    "    img = show_img(path)\n",
    "\n",
    "    return done,todo,path,img\n",
    "def save_selected(img_path):\n",
    "    img_name = os.path.basename(img_path)\n",
    "    if img_path.startswith('http'):\n",
    "        img = get_url_img(img_path).convert('RGB')\n",
    "        img_path = 'tmp.jpg'\n",
    "        img.save(img_path)\n",
    "    selected_files = set(os.listdir(selected_dir))\n",
    "    \n",
    "    msg = ''\n",
    "    if img_name not in selected_files:\n",
    "        save_path = os.path.join(selected_dir,img_name)\n",
    "        img = Image.open(img_path)\n",
    "        img.save(save_path)\n",
    "        msg = 'Selected images number = {}\\n'.format(len(selected_files)+1)+\\\n",
    "              'Image saved successfully!\\n'+\\\n",
    "              'Saved image name : {}'.format(img_name)\n",
    "    else:\n",
    "        msg = 'Selected images number = {}\\n'.format(len(selected_files))+\\\n",
    "        'Duplicate image, not saved again.'\n",
    "    return msg \n",
    "def get_default_msg():\n",
    "    if not os.path.exists(selected_dir):\n",
    "        os.mkdir(selected_dir)\n",
    "    selected_files = set(os.listdir(selected_dir))\n",
    "    msg = 'Selected images number = {}\\n'.format(len(selected_files))\n",
    "    return msg "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70b778c1",
   "metadata": {
    "code_folding": [],
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "with gr.Blocks() as demo:\n",
    "    with gr.Row():\n",
    "        total = gr.Number(len(files),label='Total')\n",
    "        with gr.Row(scale = 1):\n",
    "            bn_before = gr.Button(\"Previous\")\n",
    "            bn_next = gr.Button(\"Next\")\n",
    "        with gr.Row(scale = 2):\n",
    "            done = gr.Number(0,label='Done')\n",
    "            todo = gr.Number(len(files),label='To do')\n",
    "    path = gr.Text(files[0],lines=1, label='Current image path')\n",
    "    feedback_button = gr.Button(\"Select image\",variant=\"primary\")\n",
    "    msg = gr.TextArea(value=get_default_msg,lines=3,max_lines = 5)\n",
    "    img = gr.Image(value = show_img(files[0]),type='pil')\n",
    "    \n",
    "    bn_before.click(fn_before,\n",
    "                 inputs= [done,todo], \n",
    "                 outputs=[done,todo,path,img])\n",
    "    bn_next.click(fn_next,\n",
    "                 inputs= [done,todo], \n",
    "                 outputs=[done,todo,path,img])\n",
    "    feedback_button.click(save_selected,\n",
    "                         inputs = path,\n",
    "                         outputs = msg\n",
    "                         )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab5c443a",
   "metadata": {
    "code_folding": [],
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "demo.launch()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d46c6fb",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## 6. Hosting on Hugging Face (Difficulty: ⭐️⭐️)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05326557",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
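    "To give collaborators a permanent demo of our model app, we can deploy the Gradio app to a Hugging Face Space, which is free of charge.\n",
    "\n",
    "The steps are:\n",
    "\n",
    "1. Register a Hugging Face account: https://huggingface.co/join\n",
    "\n",
    "2. Create a project under Spaces: https://huggingface.co/spaces\n",
    "\n",
    "3. The new project comes with a README; either follow its instructions or edit the app.py and requirements.txt files by hand."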
   ]
  },
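  {
   "cell_type": "markdown",
   "id": "c3f5a2d8",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
    "As a rough sketch of step 3, a minimal `app.py` could wrap the Hello World example from the first section as shown below. The names `greet` and `main` are illustrative, and the sketch assumes the Space executes `app.py` as a script; `requirements.txt` is assumed to list any extra packages the app imports. The project README remains the authority on the exact setup.\n",
    "\n",
    "```python\n",
    "# app.py: hypothetical minimal entry point for a Gradio Space\n",
    "\n",
    "def greet(name):\n",
    "    # Same toy function as the Hello World example\n",
    "    return 'Hello ' + name + '!!'\n",
    "\n",
    "\n",
    "def main():\n",
    "    # Imported lazily so greet() can be tested without gradio installed\n",
    "    import gradio as gr\n",
    "\n",
    "    demo = gr.Interface(fn=greet, inputs='text', outputs='text')\n",
    "    demo.launch()\n",
    "\n",
    "\n",
    "# Assumption: the Space runs app.py as a script, so the deployed file would end with:\n",
    "# main()\n",
    "```\n",
    "\n",
    "Keeping the model function separate from the UI code makes it easy to check locally before pushing to the Space."
   ]
  },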
  {
   "cell_type": "markdown",
   "id": "2ffc4ca2",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Reference example:\n",
    "\n",
    "https://huggingface.co/spaces/lyhue1991/yolov8_demo"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69105ae3",
   "metadata": {},
   "source": [
    "![算法美食屋logo.png](https://tva1.sinaimg.cn/large/e6c9d24egy1h41m2zugguj20k00b9q46.jpg)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
